import json

import torch

from ModularUtils.ControllerConstants import map_dictfill_to_discrete, generate_permutations
from ModularUtils.ControllerModel import get_generated_labels
from ModularUtils.Experiment_Class import Experiment
from ModularUtils.FunctionsConstant import getdoKey, asKey
from ModularUtils.FunctionsDistribution import get_joint_distributions_from_samples, calculate_TVD
from Sachs_experiment.GroundTruth.CausalGraph_Sachs import set_sachs_nonId_graph
from Train_By_Components.Synthetic_TrainByComp import get_intv_dist


# ara= torch.randn(10)*4
# print(ara)
# ret= map_to_exact_discrete(ara)
# print(ret)




def get_sachs_cf_dist(Exp, label_generators, intv_key=None, cur_mechs=None,
                      file_name="/path_to_project/SAVED_EXPERIMENTS/nonid_mnist_images/labels/P(Y|do(X1,X2),X1',X2').txt"):
    """Estimate counterfactual distributions for the first query in ``Exp.cf_queries``.

    For every evidence assignment in ``cfquery["evidence"]``:
      1. Posterior labels, latents and Gumbel noise (keyed by ``asKey(evidence)``)
         are obtained from ``rejection_sampling_optimized``.
      2. They are pushed through ``label_generators`` under ``intv_key``.
      3. The result is discretised and summarised as a joint distribution over
         ``cfquery["obs"]``.
    All distributions are written to ``file_name`` as JSON and returned.

    Args:
        Exp: experiment object. Reads ``cf_queries`` and ``Synthetic_Sample_Size``.
        label_generators: trained label generators forwarded to the sampling helpers.
        intv_key: intervention dict forwarded to ``get_generated_labels``.
            Required. The original code silently used a module-level global here.
        cur_mechs: optional collection of mechanism names. If given and it shares
            no variable with ``cfquery["obs"]``, the query is skipped.
        file_name: path the JSON result is written to.

    Returns:
        dict with keys:
          - ``"obs_comb"``: the evidence dicts, in order.
          - ``"prob"``: one distribution per evidence, with outcome keys
            stringified for JSON.
          - ``"loss"``: left empty, since no ground truth is compared here.

    Raises:
        ValueError: if ``intv_key`` is None.
    """
    if intv_key is None:
        raise ValueError("intv_key is required for a counterfactual query")

    result = {"obs_comb": [], "prob": [], "loss": []}
    cfquery = Exp.cf_queries[0]

    # Skip the query when none of its observed variables belong to the
    # mechanisms under consideration.
    if cur_mechs is not None and not set(cfquery["obs"]) & set(cur_mechs):
        return result

    n_samples = Exp.Synthetic_Sample_Size
    evidence_list = list(cfquery["evidence"])

    # NOTE(review): rejection_sampling_optimized is neither defined nor imported
    # in this file; it must be brought into scope for this function to run.
    all_posterior_label, all_posterior_latent, all_gumbel_noise = rejection_sampling_optimized(
        Exp, label_generators, n_samples, evidence_list, max_rejections=0, warn=100)

    for evidence in evidence_list:
        kev = asKey(evidence)
        posterior_label = all_posterior_label[kev]
        posterior_latent = all_posterior_latent[kev]
        gumbel_noise = all_gumbel_noise[kev]

        cf_all_labels_dict = get_generated_labels(Exp, label_generators, posterior_label, posterior_latent,
                                                  intv_key, cfquery["obs"], n_samples, gumbel_noise=gumbel_noise)
        cf_samples = map_dictfill_to_discrete(Exp, cf_all_labels_dict, cfquery["obs"])
        upd_dist = get_joint_distributions_from_samples(Exp, cfquery["obs"], cf_samples, "feature")

        print(f"CF query done for evidence:{evidence}, intv_key: {intv_key} ")

        result["obs_comb"].append(dict(evidence))
        # json cannot encode non-str keys such as tuples, so outcome keys are
        # stringified. float() also handles 0-dim tensors.
        result["prob"].append({str(k): float(v) for k, v in upd_dist.items()})

    with open(file_name, "w", encoding="utf-8") as fp:
        fp.write(json.dumps(result))

    return result


def get_sachs_interventional_dist(Exp, label_generators, obs_vars=None, intv_vars=None):
    """Compute P(obs_vars | do(intv_vars = v)) for every intervention value combination.

    The combinations come from ``generate_permutations`` applied to
    ``Exp.label_dim[lb]["feature"]`` for each intervened variable. Presumably
    this is the product of each variable's states; confirm against
    ``ControllerConstants``.

    Args:
        Exp: experiment object. Reads ``label_dim`` and ``Synthetic_Sample_Size``.
        label_generators: trained label generators forwarded to ``get_generated_labels``.
        obs_vars: variables whose joint distribution is estimated. Defaults to ``["Erk"]``.
        intv_vars: variables that are intervened on. Defaults to ``["Mek"]``.

    Returns:
        dict with keys:
          - ``"obs_comb"``: the intervention dicts, in order.
          - ``"prob"``: the matching distributions returned by
            ``get_joint_distributions_from_samples``.
          - ``"loss"``: left empty.
    """
    if obs_vars is None:
        obs_vars = ["Erk"]
    if intv_vars is None:
        intv_vars = ["Mek"]

    perms = generate_permutations([Exp.label_dim[lb]["feature"] for lb in intv_vars])
    key_vals = [dict(zip(intv_vars, comb)) for comb in perms]

    result = {"obs_comb": [], "prob": [], "loss": []}
    for intv_key in key_vals:
        # No posterior conditioning: empty label/latent dicts mean plain
        # interventional sampling.
        generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, intv_key, obs_vars,
                                                     Exp.Synthetic_Sample_Size)
        generated_labels_full = map_dictfill_to_discrete(Exp, generated_labels_dict, obs_vars)
        upd_dist = get_joint_distributions_from_samples(Exp, obs_vars, generated_labels_full, "feature")

        result["obs_comb"].append(intv_key)
        result["prob"].append(upd_dist)

    return result


# Experiment on the non-identifiable Sachs graph.
# Other Data_intervs settings tried in earlier runs:
#   [{}, {"PKA": 2}]
#   [{}, {"PKC": 1}, {"PKC": 2}]
#   [{"Mek": 0}]
Exp = Experiment(
    "Exp1",
    set_sachs_nonId_graph,
    dist_thresh=0.15,
    # Gumbel-softmax temperature and its lower bound
    Temperature=1,
    temp_min=0.1,
    # Generator / discriminator hidden layer widths
    G_hid_dims=[256, 256],
    D_hid_dims=[256, 256],
    # WGAN-GP style critic settings
    CRITIC_ITERATIONS=5,
    LAMBDA_GP=10,
    learning_rate=2 * 1e-4,  # 1e-5 was also tried
    Synthetic_Sample_Size=10000,
    intv_Sample_Size=10000,
    batch_size=100,
    features=["feature"],
    noise_states=100,
    latent_state=16,
    Data_intervs=[{}],  # observational data only
    num_epochs=500,
    obs_state=3,
    new_experiment=False,
)

# Interventional batches reuse the observational batch size.
Exp.intv_batch_size = Exp.batch_size


# Metadata file shared by all runs of this DAG.
SHARED_INFO = "/path_to_project/SAVED_EXPERIMENTS/" + Exp.Complete_DAG_desc + "/SHARED_INFO.txt"
# Explicit encoding so parsing does not depend on the platform's locale.
with open(SHARED_INFO, encoding="utf-8") as f:
    INSTANCE = json.load(f)

# INSTANCE["last_exp"] would point at the most recent run. A specific run is
# pinned here instead.
last_exp = "/path_to_project/SAVED_EXPERIMENTS/sachs_nonId_graph/Exp1/Oct_10_2022-02_23"
print(last_exp)
Exp.LOAD_MODEL_PATH = last_exp

# Mechanisms whose generators are loaded from LOAD_MODEL_PATH.
Exp.load_which_models = {"PKA": True, "Mek": True, "Erk": True, "Akt": True}
# NOTE(review): get_generators is neither defined nor imported in this file;
# it must be brought into scope for this script to run.
label_generators, optimizersMech = get_generators(Exp, Exp.load_which_models)

# Put every generator in eval mode before sampling.
for gen in label_generators:
    label_generators[gen].eval()


with torch.no_grad():
    # Sample Akt under do(PKA=2) from the trained generators. Then compare the
    # empirical distribution with hard-coded reference values using total
    # variation distance.
    feat = "feature"
    cur_mechs = ["PKA", "Mek", "Erk", "Akt"]

    compare_Var = ["Akt"]
    intv_key = {"PKA": 2}

    # No posterior conditioning (empty label/latent dicts): plain interventional sampling.
    generated_labels_dict = get_generated_labels(
        Exp, label_generators, {}, {}, intv_key, compare_Var, Exp.Synthetic_Sample_Size
    )
    generated_labels_full = map_dictfill_to_discrete(Exp, generated_labels_dict, compare_Var)
    upd_dist = get_joint_distributions_from_samples(Exp, compare_Var, generated_labels_full, feat)
    print(upd_dist)

    # Reference probabilities keyed by 1-tuples of Akt's discrete value. Their
    # provenance is not visible in this file.
    true_dist = {(0,): 0.8046, (1,): 0.19538, (2,): 1e-6}

    tvd = calculate_TVD(true_dist, upd_dist, doPrint=False)
    print(tvd)
    # get_sachs_interventional_dist(Exp, label_generators)









